# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import re
import tempfile
import unittest
from itertools import product

import numpy as np

import mindspore as ms
from mindspore import mint

from mindone.diffusers import AutoencoderKL, DDIMScheduler, LCMScheduler, UNet2DConditionModel
from mindone.diffusers.utils import logging
from mindone.diffusers.utils.testing_utils import CaptureLogger, floats_tensor, require_peft_version_greater
from mindone.peft import LoraConfig
from mindone.peft.tuners.tuners_utils import BaseTunerLayer
from mindone.peft.utils import get_peft_model_state_dict


def state_dicts_almost_equal(sd1, sd2):
    """
    Return whether two state dicts hold the same keys and element-wise close tensors.

    Args:
        sd1: Mapping from parameter name to a tensor-like value supporting `-`, `.abs()` and `.max()`.
        sd2: Mapping compared against `sd1`.

    Returns:
        `True` when both mappings have exactly the same keys and, for every key, the largest
        absolute difference between the two values does not exceed `1e-3`; `False` otherwise.
    """
    # Pairing values positionally (zip over sorted items) silently truncates or misaligns
    # entries when the key sets differ, so reject that case up front.
    if sd1.keys() != sd2.keys():
        return False

    # Stop at the first tensor pair whose maximum absolute difference exceeds the tolerance.
    return not any((sd1[name] - sd2[name]).abs().max() > 1e-3 for name in sorted(sd1))


def check_if_lora_correctly_set(model) -> bool:
    """
    Report whether `model` contains at least one PEFT tuner layer.

    Walks the `(name, cell)` pairs yielded by `model.cells_and_names()` and returns `True`
    as soon as one cell is a `BaseTunerLayer` instance, `False` if none is.
    """
    return any(isinstance(cell, BaseTunerLayer) for _, cell in model.cells_and_names())


# Candidate names of the attention-kwargs parameter of a pipeline's `__call__`, in lookup priority order.
POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]


def determine_attention_kwargs_name(pipeline_class):
    """
    Find which attention-kwargs parameter name `pipeline_class.__call__` accepts.

    Args:
        pipeline_class: Class whose `__call__` signature is inspected.

    Returns:
        The first entry of `POSSIBLE_ATTENTION_KWARGS_NAMES` that is a parameter of `__call__`.

    Raises:
        ValueError: If `__call__` accepts none of the candidate names.
    """
    call_signature_keys = inspect.signature(pipeline_class.__call__).parameters.keys()

    # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
    for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
        if possible_attention_kwargs in call_signature_keys:
            return possible_attention_kwargs

    # Previously the result variable was left unbound here, which surfaced as an
    # UnboundLocalError instead of a meaningful failure.
    raise ValueError(
        f"`{pipeline_class.__name__}.__call__` accepts none of the expected attention kwargs names: "
        f"{POSSIBLE_ATTENTION_KWARGS_NAMES}."
    )


class PeftLoraLoaderMixinTests:
    """
    Shared LoRA test cases for pipelines.

    The class attributes below are configuration hooks read by the helper and test methods.
    The tests call `self.assertTrue` / `self.assertFalse`, which are not defined here, so the
    mixin is presumably combined with `unittest.TestCase` by concrete test classes.
    """

    # Pipeline class under test.
    pipeline_class = None

    # Default scheduler class and its constructor kwargs, plus the scheduler classes the tests iterate over.
    scheduler_cls = None
    scheduler_kwargs = None
    scheduler_classes = [DDIMScheduler, LCMScheduler]

    # Mutually exclusive flags (checked in `get_dummy_components`).
    has_two_text_encoders = False
    has_three_text_encoders = False

    # First text encoder: class and `from_pretrained` arguments.
    text_encoder_cls = None
    text_encoder_id = None
    text_encoder_subfolder = ""
    text_encoder_revision = "main"

    # Optional second text encoder.
    text_encoder_2_cls = None
    text_encoder_2_id = None
    text_encoder_2_subfolder = ""
    text_encoder_2_revision = "main"

    # Optional third text encoder.
    text_encoder_3_cls = None
    text_encoder_3_id = None
    text_encoder_3_subfolder = ""
    text_encoder_3_revision = "main"

    # Tokenizers paired with the text encoders above.
    tokenizer_cls = None
    tokenizer_id = None
    tokenizer_subfolder = ""

    tokenizer_2_cls = None
    tokenizer_2_id = None
    tokenizer_2_subfolder = ""

    tokenizer_3_cls = None
    tokenizer_3_id = None
    tokenizer_3_subfolder = ""

    # Denoiser configuration: either `unet_kwargs` or `transformer_cls` + `transformer_kwargs`.
    unet_kwargs = None
    transformer_cls = None
    transformer_kwargs = None

    # VAE class and its constructor kwargs.
    vae_cls = AutoencoderKL
    vae_kwargs = None

    # Module names targeted by the LoRA configs built in `get_dummy_components`.
    text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
    denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]

    def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
        """
        Build the pipeline components and the two LoRA configs used by the tests.

        Args:
            scheduler_cls: Scheduler class to instantiate; `self.scheduler_cls` is used when `None`.
            use_dora: Currently unused; it is not forwarded to `LoraConfig`.
            lora_alpha: LoRA alpha for both configs; the rank (4) is used when `None`.

        Returns:
            A tuple `(pipeline_components, text_lora_config, denoiser_lora_config)`.

        Raises:
            ValueError: If both `unet_kwargs` and `transformer_kwargs` are set, or if both
                `has_two_text_encoders` and `has_three_text_encoders` are true.
        """
        if self.unet_kwargs and self.transformer_kwargs:
            raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
        if self.has_two_text_encoders and self.has_three_text_encoders:
            raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.")

        if scheduler_cls is None:
            scheduler_cls = self.scheduler_cls
        rank = 4
        if lora_alpha is None:
            lora_alpha = rank

        uses_unet = self.unet_kwargs is not None
        has_second_encoder = self.text_encoder_2_cls is not None
        has_third_encoder = self.text_encoder_3_cls is not None

        # The seed is reset right before the denoiser and again right before the VAE are created.
        ms.manual_seed(0)
        if uses_unet:
            denoiser = UNet2DConditionModel(**self.unet_kwargs)
        else:
            denoiser = self.transformer_cls(**self.transformer_kwargs)

        scheduler = scheduler_cls(**self.scheduler_kwargs)

        ms.manual_seed(0)
        vae = self.vae_cls(**self.vae_kwargs)

        text_encoder = self.text_encoder_cls.from_pretrained(
            self.text_encoder_id, subfolder=self.text_encoder_subfolder, revision=self.text_encoder_revision
        )
        tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id, subfolder=self.tokenizer_subfolder)

        if has_second_encoder:
            text_encoder_2 = self.text_encoder_2_cls.from_pretrained(
                self.text_encoder_2_id, subfolder=self.text_encoder_2_subfolder, revision=self.text_encoder_2_revision
            )
            tokenizer_2 = self.tokenizer_2_cls.from_pretrained(
                self.tokenizer_2_id, subfolder=self.tokenizer_2_subfolder
            )

        if has_third_encoder:
            text_encoder_3 = self.text_encoder_3_cls.from_pretrained(
                self.text_encoder_3_id, subfolder=self.text_encoder_3_subfolder, revision=self.text_encoder_3_revision
            )
            tokenizer_3 = self.tokenizer_3_cls.from_pretrained(
                self.tokenizer_3_id, subfolder=self.tokenizer_3_subfolder
            )

        # Both configs share everything except the targeted modules; `use_dora` is not passed on.
        shared_lora_kwargs = {"r": rank, "lora_alpha": lora_alpha, "init_lora_weights": False}
        text_lora_config = LoraConfig(target_modules=self.text_encoder_target_modules, **shared_lora_kwargs)
        denoiser_lora_config = LoraConfig(target_modules=self.denoiser_target_modules, **shared_lora_kwargs)

        pipeline_components = dict(scheduler=scheduler, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer)

        # Denoiser
        if uses_unet:
            pipeline_components["unet"] = denoiser
        elif self.transformer_kwargs is not None:
            pipeline_components["transformer"] = denoiser

        # Remaining text encoders.
        if has_second_encoder:
            pipeline_components["tokenizer_2"] = tokenizer_2
            pipeline_components["text_encoder_2"] = text_encoder_2
        if has_third_encoder:
            pipeline_components["tokenizer_3"] = tokenizer_3
            pipeline_components["text_encoder_3"] = text_encoder_3

        # Optional constructor arguments are passed as `None` when the pipeline declares them.
        init_params = inspect.signature(self.pipeline_class.__init__).parameters
        for optional_name in ("safety_checker", "feature_extractor", "image_encoder"):
            if optional_name in init_params:
                pipeline_components[optional_name] = None

        return pipeline_components, text_lora_config, denoiser_lora_config

    @property
    def output_shape(self):
        raise NotImplementedError

    def get_dummy_inputs(self, with_generator=True):
        """
        Create dummy inputs for the pipeline tests.

        Args:
            with_generator: When true, a `np.random.default_rng(0)` generator is added to the
                pipeline inputs under the `"generator"` key.

        Returns:
            A tuple `(noise, input_ids, pipeline_inputs)`: a float tensor of shape (1, 4, 32, 32),
            random integer ids of shape (1, 10), and the keyword arguments for the pipeline call.
        """
        batch_size, num_channels, sequence_length = 1, 4, 10
        sizes = (32, 32)

        noise = floats_tensor((batch_size, num_channels) + sizes)
        id_shape = (batch_size, sequence_length)
        input_ids = mint.randint(1, sequence_length, id_shape, generator=ms.manual_seed(0))

        pipeline_inputs = dict(
            prompt="A painting of a squirrel eating a burger",
            num_inference_steps=5,
            guidance_scale=6.0,
            output_type="np",
        )
        if with_generator:
            pipeline_inputs["generator"] = np.random.default_rng(0)

        return noise, input_ids, pipeline_inputs

    # Copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
    def get_dummy_tokens(self):
        """Return a dict whose `"input_ids"` entry holds random integer ids of shape (1, 77)."""
        max_seq_length = 77
        token_ids = mint.randint(2, 56, (1, max_seq_length), generator=ms.manual_seed(0))
        return {"input_ids": token_ids}

    def _get_lora_state_dicts(self, modules_to_save):
        state_dicts = {}
        for module_name, module in modules_to_save.items():
            if module is not None:
                state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module)
        return state_dicts

    def _get_modules_to_save(self, pipe, has_denoiser=False):
        modules_to_save = {}
        lora_loadable_modules = self.pipeline_class._lora_loadable_modules

        if (
            "text_encoder" in lora_loadable_modules
            and hasattr(pipe, "text_encoder")
            and getattr(pipe.text_encoder, "peft_config", None) is not None
        ):
            modules_to_save["text_encoder"] = pipe.text_encoder

        if (
            "text_encoder_2" in lora_loadable_modules
            and hasattr(pipe, "text_encoder_2")
            and getattr(pipe.text_encoder_2, "peft_config", None) is not None
        ):
            modules_to_save["text_encoder_2"] = pipe.text_encoder_2

        if has_denoiser:
            if "unet" in lora_loadable_modules and hasattr(pipe, "unet"):
                modules_to_save["unet"] = pipe.unet

            if "transformer" in lora_loadable_modules and hasattr(pipe, "transformer"):
                modules_to_save["transformer"] = pipe.transformer

        return modules_to_save

    def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
        if text_lora_config is not None:
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

        if denoiser_lora_config is not None:
            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
        else:
            denoiser = None

        if text_lora_config is not None and self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )
        return pipe, denoiser

    def test_simple_inference(self):
        """
        Runs the pipeline without any adapter, once per scheduler class, and checks
        that the first output has the expected shape
        """
        for scheduler_cls in self.scheduler_classes:
            components, _, _ = self.get_dummy_components(scheduler_cls)
            pipeline = self.pipeline_class(**components)
            pipeline.set_progress_bar_config(disable=None)

            _, _, call_kwargs = self.get_dummy_inputs()
            baseline = pipeline(**call_kwargs)[0]
            shape_matches = baseline.shape == self.output_shape
            self.assertTrue(shape_matches)

    def test_simple_inference_with_text_lora(self):
        """
        Adds a LoRA adapter to the text encoder and checks that the pipeline output
        no longer matches the adapter-free baseline
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipeline = self.pipeline_class(**components)
            pipeline.set_progress_bar_config(disable=None)
            _, _, call_kwargs = self.get_dummy_inputs(with_generator=False)

            # Each call receives a new generator seeded with 0.
            baseline = pipeline(**call_kwargs, generator=np.random.default_rng(0))[0]
            self.assertTrue(baseline.shape == self.output_shape)

            pipeline, _ = self.add_adapters_to_pipeline(pipeline, text_lora_config, denoiser_lora_config=None)

            with_lora = pipeline(**call_kwargs, generator=np.random.default_rng(0))[0]
            lora_changed_output = not np.allclose(with_lora, baseline, atol=1e-3, rtol=1e-3)
            self.assertTrue(lora_changed_output, "Lora should change the output")

    def test_simple_inference_with_text_lora_and_scale(self):
        """
        Checks the effect of the `scale` attention kwarg on a pipeline whose text encoder
        carries a LoRA adapter: a scale of 0.5 must alter the output and a scale of 0.0
        must reproduce the adapter-free baseline
        """
        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipeline = self.pipeline_class(**components)
            pipeline.set_progress_bar_config(disable=None)
            _, _, call_kwargs = self.get_dummy_inputs(with_generator=False)

            baseline = pipeline(**call_kwargs, generator=np.random.default_rng(0))[0]
            self.assertTrue(baseline.shape == self.output_shape)

            pipeline, _ = self.add_adapters_to_pipeline(pipeline, text_lora_config, denoiser_lora_config=None)

            default_scale = pipeline(**call_kwargs, generator=np.random.default_rng(0))[0]
            lora_changed_output = not np.allclose(default_scale, baseline, atol=1e-3, rtol=1e-3)
            self.assertTrue(lora_changed_output, "Lora should change the output")

            half_scale_kwargs = {attention_kwargs_name: {"scale": 0.5}}
            half_scale = pipeline(**call_kwargs, generator=np.random.default_rng(0), **half_scale_kwargs)[0]
            scale_changed_output = not np.allclose(default_scale, half_scale, atol=1e-3, rtol=1e-3)
            self.assertTrue(scale_changed_output, "Lora + scale should change the output")

            zero_scale_kwargs = {attention_kwargs_name: {"scale": 0.0}}
            zero_scale = pipeline(**call_kwargs, generator=np.random.default_rng(0), **zero_scale_kwargs)[0]
            matches_baseline = np.allclose(baseline, zero_scale, atol=1e-3, rtol=1e-3)
            self.assertTrue(matches_baseline, "Lora + 0 scale should lead to same result as no LoRA")

    def test_simple_inference_with_text_lora_fused(self):
        """
        Adds a LoRA adapter to the text encoder, fuses it via `fuse_lora`, and checks that
        the LoRA layers are still present and that the output differs from the baseline
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipeline = self.pipeline_class(**components)
            pipeline.set_progress_bar_config(disable=None)
            _, _, call_kwargs = self.get_dummy_inputs(with_generator=False)

            baseline = pipeline(**call_kwargs, generator=np.random.default_rng(0))[0]
            self.assertTrue(baseline.shape == self.output_shape)

            pipeline, _ = self.add_adapters_to_pipeline(pipeline, text_lora_config, denoiser_lora_config=None)

            pipeline.fuse_lora()
            # Fusing should still keep the LoRA layers
            encoder_has_lora = check_if_lora_correctly_set(pipeline.text_encoder)
            self.assertTrue(encoder_has_lora, "Lora not correctly set in text encoder")

            has_extra_encoder = self.has_two_text_encoders or self.has_three_text_encoders
            if has_extra_encoder and "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                encoder_2_has_lora = check_if_lora_correctly_set(pipeline.text_encoder_2)
                self.assertTrue(encoder_2_has_lora, "Lora not correctly set in text encoder 2")

            fused = pipeline(**call_kwargs, generator=np.random.default_rng(0))[0]
            same_as_baseline = np.allclose(fused, baseline, atol=1e-3, rtol=1e-3)
            self.assertFalse(same_as_baseline, "Fused lora should change the output")

    def test_simple_inference_with_text_lora_unloaded(self):
        """
        Tests a simple inference with lora attached to text encoder, then unloads the lora weights
        and makes sure the LoRA layers are gone and the output matches the adapter-free baseline
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=np.random.default_rng(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

            pipe.unload_lora_weights()
            # unloading should remove the LoRA layers
            self.assertFalse(
                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
            )

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertFalse(
                        check_if_lora_correctly_set(pipe.text_encoder_2),
                        "Lora not correctly unloaded in text encoder 2",
                    )

            output_unloaded = pipe(**inputs, generator=np.random.default_rng(0))[0]
            # The assertion checks equality with the baseline, so the failure message must say so
            # (it previously claimed that fused LoRA should change the output).
            self.assertTrue(
                np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
                "Unloading lora should restore the output obtained without LoRA",
            )

    def test_simple_inference_with_text_lora_save_load(self):
        """
        Saves the text-encoder LoRA weights with `save_lora_weights`, unloads them, loads them
        back from the checkpoint file and checks that the output is unchanged.
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipeline = self.pipeline_class(**components)
            pipeline.set_progress_bar_config(disable=None)
            _, _, call_kwargs = self.get_dummy_inputs(with_generator=False)

            baseline = pipeline(**call_kwargs, generator=np.random.default_rng(0))[0]
            self.assertTrue(baseline.shape == self.output_shape)

            pipeline, _ = self.add_adapters_to_pipeline(pipeline, text_lora_config, denoiser_lora_config=None)

            images_before_save = pipeline(**call_kwargs, generator=np.random.default_rng(0))[0]

            with tempfile.TemporaryDirectory() as tmpdirname:
                modules_to_save = self._get_modules_to_save(pipeline)
                lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
                weights_path = os.path.join(tmpdirname, "pytorch_lora_weights.ckpt")

                self.pipeline_class.save_lora_weights(
                    save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
                )

                self.assertTrue(os.path.isfile(weights_path))
                pipeline.unload_lora_weights()
                pipeline.load_lora_weights(weights_path)

            # The saved modules must carry LoRA layers again after reloading.
            for module_name, module in modules_to_save.items():
                self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

            images_after_load = pipeline(**call_kwargs, generator=np.random.default_rng(0))[0]
            round_trip_matches = np.allclose(images_before_save, images_after_load, atol=1e-3, rtol=1e-3)
            self.assertTrue(round_trip_matches, "Loading from saved checkpoints should give same results.")

    def test_simple_inference_with_partial_text_lora(self):
        """
        Attaches a text-encoder LoRA with per-module ranks, reloads it from a state dict that
        omits the `text_model.encoder.layers.4` entries, and checks that the output changes
        """

        def prefixed_partial_state_dict(prefix, encoder):
            # PEFT state dict of `encoder` without `layers.4`, keys prefixed with the component name.
            return {
                f"{prefix}.{module_name}": param
                for module_name, param in get_peft_model_state_dict(encoder).items()
                if "text_model.encoder.layers.4" not in module_name
            }

        for scheduler_cls in self.scheduler_classes:
            components, _, _ = self.get_dummy_components(scheduler_cls)
            # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
            text_lora_config = LoraConfig(
                r=4,
                rank_pattern={self.text_encoder_target_modules[i]: i + 1 for i in range(3)},
                lora_alpha=4,
                target_modules=self.text_encoder_target_modules,
                init_lora_weights=False,
                # use_dora=False,
            )
            pipeline = self.pipeline_class(**components)
            pipeline.set_progress_bar_config(disable=None)
            _, _, call_kwargs = self.get_dummy_inputs(with_generator=False)

            baseline = pipeline(**call_kwargs, generator=np.random.default_rng(0))[0]
            self.assertTrue(baseline.shape == self.output_shape)

            pipeline, _ = self.add_adapters_to_pipeline(pipeline, text_lora_config, denoiser_lora_config=None)

            # Gather the partial state dicts to ensure `load_lora_into_text_encoder`
            # supports missing layers (PR#8324).
            state_dict = {}
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                state_dict = prefixed_partial_state_dict("text_encoder", pipeline.text_encoder)

            has_extra_encoder = self.has_two_text_encoders or self.has_three_text_encoders
            if has_extra_encoder and "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                state_dict.update(prefixed_partial_state_dict("text_encoder_2", pipeline.text_encoder_2))

            full_lora = pipeline(**call_kwargs, generator=np.random.default_rng(0))[0]
            lora_changed_output = not np.allclose(full_lora, baseline, atol=1e-3, rtol=1e-3)
            self.assertTrue(lora_changed_output, "Lora should change the output")

            # Unload lora and load it back using the pipe.load_lora_weights machinery
            pipeline.unload_lora_weights()
            pipeline.load_lora_weights(state_dict)

            partial_lora = pipeline(**call_kwargs, generator=np.random.default_rng(0))[0]
            removal_changed_output = not np.allclose(partial_lora, full_lora, atol=1e-3, rtol=1e-3)
            self.assertTrue(removal_changed_output, "Removing adapters should change the output")

    def test_simple_inference_save_pretrained_with_text_lora(self):
        """
        Saves a pipeline with a text-encoder LoRA through `save_pretrained`, reloads it with
        `from_pretrained`, and checks that the LoRA layers and the output are preserved
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
            pipeline = self.pipeline_class(**components)
            pipeline.set_progress_bar_config(disable=None)
            _, _, call_kwargs = self.get_dummy_inputs(with_generator=False)

            baseline = pipeline(**call_kwargs, generator=np.random.default_rng(0))[0]
            self.assertTrue(baseline.shape == self.output_shape)

            pipeline, _ = self.add_adapters_to_pipeline(pipeline, text_lora_config, denoiser_lora_config=None)
            images_before_save = pipeline(**call_kwargs, generator=np.random.default_rng(0))[0]

            with tempfile.TemporaryDirectory() as tmpdirname:
                pipeline.save_pretrained(tmpdirname)

                reloaded = self.pipeline_class.from_pretrained(tmpdirname)

            loadable = self.pipeline_class._lora_loadable_modules
            if "text_encoder" in loadable:
                encoder_has_lora = check_if_lora_correctly_set(reloaded.text_encoder)
                self.assertTrue(encoder_has_lora, "Lora not correctly set in text encoder")

            has_extra_encoder = self.has_two_text_encoders or self.has_three_text_encoders
            if has_extra_encoder and "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                encoder_2_has_lora = check_if_lora_correctly_set(reloaded.text_encoder_2)
                self.assertTrue(encoder_2_has_lora, "Lora not correctly set in text encoder 2")

            images_after_reload = reloaded(**call_kwargs, generator=np.random.default_rng(0))[0]
            round_trip_matches = np.allclose(images_before_save, images_after_reload, atol=1e-3, rtol=1e-3)
            self.assertTrue(round_trip_matches, "Loading from saved checkpoints should give same results.")

    def test_simple_inference_with_text_denoiser_lora_save_load(self):
        """
        Saves the LoRA weights of the text encoder and the denoiser with `save_lora_weights`,
        unloads them, loads them back from the checkpoint file and checks the output is unchanged
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipeline = self.pipeline_class(**components)
            pipeline.set_progress_bar_config(disable=None)
            _, _, call_kwargs = self.get_dummy_inputs(with_generator=False)

            baseline = pipeline(**call_kwargs, generator=np.random.default_rng(0))[0]
            self.assertTrue(baseline.shape == self.output_shape)

            pipeline, _ = self.add_adapters_to_pipeline(pipeline, text_lora_config, denoiser_lora_config)

            images_before_save = pipeline(**call_kwargs, generator=np.random.default_rng(0))[0]

            with tempfile.TemporaryDirectory() as tmpdirname:
                modules_to_save = self._get_modules_to_save(pipeline, has_denoiser=True)
                lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
                weights_path = os.path.join(tmpdirname, "pytorch_lora_weights.ckpt")
                self.pipeline_class.save_lora_weights(
                    save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
                )

                self.assertTrue(os.path.isfile(weights_path))
                pipeline.unload_lora_weights()
                pipeline.load_lora_weights(weights_path)

            # The saved modules must carry LoRA layers again after reloading.
            for module_name, module in modules_to_save.items():
                self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

            images_after_load = pipeline(**call_kwargs, generator=np.random.default_rng(0))[0]
            round_trip_matches = np.allclose(images_before_save, images_after_load, atol=1e-3, rtol=1e-3)
            self.assertTrue(round_trip_matches, "Loading from saved checkpoints should give same results.")

    def test_simple_inference_with_text_denoiser_lora_and_scale(self):
        """
        Checks the effect of the `scale` attention kwarg when both the text encoder and the
        denoiser carry LoRA adapters: 0.5 must alter the output, 0.0 must reproduce the
        adapter-free baseline, and the text encoder scaling must be 1.0 afterwards
        """
        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipeline = self.pipeline_class(**components)
            pipeline.set_progress_bar_config(disable=None)
            _, _, call_kwargs = self.get_dummy_inputs(with_generator=False)

            baseline = pipeline(**call_kwargs, generator=np.random.default_rng(0))[0]
            self.assertTrue(baseline.shape == self.output_shape)

            pipeline, _ = self.add_adapters_to_pipeline(pipeline, text_lora_config, denoiser_lora_config)

            default_scale = pipeline(**call_kwargs, generator=np.random.default_rng(0))[0]
            lora_changed_output = not np.allclose(default_scale, baseline, atol=1e-3, rtol=1e-3)
            self.assertTrue(lora_changed_output, "Lora should change the output")

            half_scale_kwargs = {attention_kwargs_name: {"scale": 0.5}}
            half_scale = pipeline(**call_kwargs, generator=np.random.default_rng(0), **half_scale_kwargs)[0]
            scale_changed_output = not np.allclose(default_scale, half_scale, atol=1e-3, rtol=1e-3)
            self.assertTrue(scale_changed_output, "Lora + scale should change the output")

            zero_scale_kwargs = {attention_kwargs_name: {"scale": 0.0}}
            zero_scale = pipeline(**call_kwargs, generator=np.random.default_rng(0), **zero_scale_kwargs)[0]
            matches_baseline = np.allclose(baseline, zero_scale, atol=1e-3, rtol=1e-3)
            self.assertTrue(matches_baseline, "Lora + 0 scale should lead to same result as no LoRA")

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                # After the scaled calls the adapter scaling of the first attention layer must be 1.0.
                first_q_proj = pipeline.text_encoder.text_model.encoder.layers[0].self_attn.q_proj
                scaling_restored = first_q_proj.scaling["default"] == 1.0
                self.assertTrue(scaling_restored, "The scaling parameter has not been correctly restored!")

    def test_simple_inference_with_text_lora_denoiser_fused(self):
        """
        Adds LoRA adapters to the text encoder and the denoiser, fuses all LoRA-loadable
        components, and checks that the LoRA layers remain and the output differs from the baseline
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipeline = self.pipeline_class(**components)
            pipeline.set_progress_bar_config(disable=None)
            _, _, call_kwargs = self.get_dummy_inputs(with_generator=False)

            baseline = pipeline(**call_kwargs, generator=np.random.default_rng(0))[0]
            self.assertTrue(baseline.shape == self.output_shape)

            pipeline, denoiser = self.add_adapters_to_pipeline(pipeline, text_lora_config, denoiser_lora_config)

            pipeline.fuse_lora(components=self.pipeline_class._lora_loadable_modules)

            # Fusing should still keep the LoRA layers
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                encoder_has_lora = check_if_lora_correctly_set(pipeline.text_encoder)
                self.assertTrue(encoder_has_lora, "Lora not correctly set in text encoder")

            denoiser_has_lora = check_if_lora_correctly_set(denoiser)
            self.assertTrue(denoiser_has_lora, "Lora not correctly set in denoiser")

            has_extra_encoder = self.has_two_text_encoders or self.has_three_text_encoders
            if has_extra_encoder and "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                encoder_2_has_lora = check_if_lora_correctly_set(pipeline.text_encoder_2)
                self.assertTrue(encoder_2_has_lora, "Lora not correctly set in text encoder 2")

            fused = pipeline(**call_kwargs, generator=np.random.default_rng(0))[0]
            same_as_baseline = np.allclose(fused, baseline, atol=1e-3, rtol=1e-3)
            self.assertFalse(same_as_baseline, "Fused lora should change the output")

    def test_simple_inference_with_text_denoiser_lora_unloaded(self):
        """
        Tests a simple inference with lora attached to text encoder and denoiser, then unloads the lora
        weights and makes sure the LoRA layers are gone and the output matches the adapter-free baseline
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=np.random.default_rng(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

            pipe.unload_lora_weights()
            # unloading should remove the LoRA layers
            self.assertFalse(
                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
            )
            self.assertFalse(check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertFalse(
                        check_if_lora_correctly_set(pipe.text_encoder_2),
                        "Lora not correctly unloaded in text encoder 2",
                    )

            output_unloaded = pipe(**inputs, generator=np.random.default_rng(0))[0]
            # The assertion checks equality with the baseline, so the failure message must say so
            # (it previously claimed that fused LoRA should change the output).
            self.assertTrue(
                np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
                "Unloading lora should restore the output obtained without LoRA",
            )

    def test_simple_inference_with_text_denoiser_lora_unfused(
        self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
    ):
        """
        Fuses the LoRA adapters attached to the text encoder(s) and the denoiser, runs inference, unfuses them
        again and checks that the LoRA layers are still in place and that both runs give matching outputs

        Args:
            expected_atol: absolute tolerance used when comparing the fused and the unfused output
            expected_rtol: relative tolerance used when comparing the fused and the unfused output
        """
        keep_layers_msg = "Unfuse should still keep LoRA layers"

        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
            lora_modules = self.pipeline_class._lora_loadable_modules

            pipe.fuse_lora(components=lora_modules)
            self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
            fused_output = pipe(**inputs, generator=np.random.default_rng(0))[0]

            pipe.unfuse_lora(components=lora_modules)
            self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
            unfused_output = pipe(**inputs, generator=np.random.default_rng(0))[0]

            # unlike unloading, unfusing must leave the LoRA layers attached
            if "text_encoder" in lora_modules:
                self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), keep_layers_msg)

            self.assertTrue(check_if_lora_correctly_set(denoiser), keep_layers_msg)

            has_extra_text_encoder = self.has_two_text_encoders or self.has_three_text_encoders
            if has_extra_text_encoder and "text_encoder_2" in lora_modules:
                self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder_2), keep_layers_msg)

            # Fuse and unfuse should lead to the same results
            outputs_match = np.allclose(fused_output, unfused_output, atol=expected_atol, rtol=expected_rtol)
            self.assertTrue(outputs_match, "Fused lora should not change the output")

    def test_simple_inference_with_text_denoiser_multi_adapter(self):
        """
        Attaches two LoRA adapters to the text encoder(s) and the denoiser and compares the outputs obtained
        with only the first, only the second, both, and no adapter active
        """

        def infer(pipeline, call_kwargs):
            # every run gets a freshly seeded generator so that outputs are comparable
            return pipeline(**call_kwargs, generator=np.random.default_rng(0))[0]

        def close(first, second):
            return np.allclose(first, second, atol=1e-3, rtol=1e-3)

        adapter_names = ("adapter-1", "adapter-2")

        np.random.seed(0)
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = infer(pipe, inputs)

            lora_modules = self.pipeline_class._lora_loadable_modules
            if "text_encoder" in lora_modules:
                for adapter_name in adapter_names:
                    pipe.text_encoder.add_adapter(text_lora_config, adapter_name)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            for adapter_name in adapter_names:
                denoiser.add_adapter(denoiser_lora_config, adapter_name)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            has_extra_text_encoder = self.has_two_text_encoders or self.has_three_text_encoders
            if has_extra_text_encoder and "text_encoder_2" in lora_modules:
                for adapter_name in adapter_names:
                    pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            # each adapter selection must move the output away from the adapter-free baseline
            pipe.set_adapters("adapter-1")
            output_adapter_1 = infer(pipe, inputs)
            self.assertFalse(close(output_no_lora, output_adapter_1), "Adapter outputs should be different.")

            pipe.set_adapters("adapter-2")
            output_adapter_2 = infer(pipe, inputs)
            self.assertFalse(close(output_no_lora, output_adapter_2), "Adapter outputs should be different.")

            pipe.set_adapters(list(adapter_names))
            output_adapter_mixed = infer(pipe, inputs)
            self.assertFalse(close(output_no_lora, output_adapter_mixed), "Adapter outputs should be different.")

            # the three adapter selections must also differ from one another
            pairwise_checks = (
                (output_adapter_1, output_adapter_2, "Adapter 1 and 2 should give different results"),
                (
                    output_adapter_1,
                    output_adapter_mixed,
                    "Adapter 1 and mixed adapters should give different results",
                ),
                (
                    output_adapter_2,
                    output_adapter_mixed,
                    "Adapter 2 and mixed adapters should give different results",
                ),
            )
            for left, right, message in pairwise_checks:
                self.assertFalse(close(left, right), message)

            pipe.disable_lora()
            output_disabled = infer(pipe, inputs)

            self.assertTrue(
                close(output_no_lora, output_disabled),
                "output with no lora and output with lora disabled should give same results",
            )

    def test_wrong_adapter_name_raises_error(self):
        """
        Checks that `set_adapters` raises a `ValueError` for an adapter name that was never added, and that
        the pipeline still runs once the name of the existing adapter is set afterwards
        """
        adapter_name = "adapter-1"

        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        pipe, _ = self.add_adapters_to_pipeline(
            pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
        )

        with self.assertRaises(ValueError) as raised:
            pipe.set_adapters("test")

        error_text = str(raised.exception)
        self.assertTrue("not in the list of present adapters" in error_text)

        # the failed call must not prevent setting the valid adapter and running inference
        pipe.set_adapters(adapter_name)
        _ = pipe(**inputs, generator=np.random.default_rng(0))[0]

    def test_multiple_wrong_adapter_name_raises_error(self):
        """
        Passes `adapter_weights` keyed by component names that are not part of the pipeline and checks that
        the captured log output lists all of them; a regular `set_adapters` call must still work afterwards
        """
        adapter_name = "adapter-1"
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        pipe, _ = self.add_adapters_to_pipeline(
            pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
        )

        scale_with_wrong_components = dict.fromkeys(("foo", "bar", "tik"), 0.0)
        lora_logger = logging.get_logger("mindone.diffusers.loaders.lora_base")
        lora_logger.setLevel(30)  # 30 corresponds to WARNING in the standard `logging` levels
        with CaptureLogger(lora_logger) as captured:
            pipe.set_adapters(adapter_name, adapter_weights=scale_with_wrong_components)

        # the expected message lists the unknown component names in sorted order
        wrong_components = sorted(scale_with_wrong_components)
        msg = f"The following components in `adapter_weights` are not part of the pipeline: {wrong_components}. "
        self.assertTrue(msg in str(captured.out))

        # test this works.
        pipe.set_adapters(adapter_name)
        _ = pipe(**inputs, generator=np.random.default_rng(0))[0]

    def test_simple_inference_with_text_denoiser_block_scale(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        one adapter and set different weights for different blocks (i.e. block lora)
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # baseline without any adapter
            output_no_lora = pipe(**inputs, generator=np.random.default_rng(0))[0]

            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            # register the denoiser adapter under the same name that `set_adapters` activates and scales below
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            # first scale set: text encoder plus the "down" part of the unet
            weights_1 = {"text_encoder": 2, "unet": {"down": 5}}
            pipe.set_adapters("adapter-1", weights_1)
            output_weights_1 = pipe(**inputs, generator=np.random.default_rng(0))[0]

            # second scale set: only the "up" part of the unet
            weights_2 = {"unet": {"up": 5}}
            pipe.set_adapters("adapter-1", weights_2)
            output_weights_2 = pipe(**inputs, generator=np.random.default_rng(0))[0]

            self.assertFalse(
                np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3),
                "LoRA weights 1 and 2 should give different results",
            )
            self.assertFalse(
                np.allclose(output_no_lora, output_weights_1, atol=1e-3, rtol=1e-3),
                "No adapter and LoRA weights 1 should give different results",
            )
            self.assertFalse(
                np.allclose(output_no_lora, output_weights_2, atol=1e-3, rtol=1e-3),
                "No adapter and LoRA weights 2 should give different results",
            )

            pipe.disable_lora()
            output_disabled = pipe(**inputs, generator=np.random.default_rng(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )

    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
        """
        Attaches two LoRA adapters to the text encoder(s) and the denoiser, activates them with per-block
        scale dicts (i.e. block lora) and compares the outputs of the different adapter selections
        """

        def infer(pipeline, call_kwargs):
            # every run gets a freshly seeded generator so that outputs are comparable
            return pipeline(**call_kwargs, generator=np.random.default_rng(0))[0]

        def close(first, second):
            return np.allclose(first, second, atol=1e-3, rtol=1e-3)

        adapter_names = ("adapter-1", "adapter-2")

        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = infer(pipe, inputs)

            lora_modules = self.pipeline_class._lora_loadable_modules
            if "text_encoder" in lora_modules:
                for adapter_name in adapter_names:
                    pipe.text_encoder.add_adapter(text_lora_config, adapter_name)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            for adapter_name in adapter_names:
                denoiser.add_adapter(denoiser_lora_config, adapter_name)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            has_extra_text_encoder = self.has_two_text_encoders or self.has_three_text_encoders
            if has_extra_text_encoder and "text_encoder_2" in lora_modules:
                for adapter_name in adapter_names:
                    pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            scales_1 = {"text_encoder": 2, "unet": {"down": 5}}
            scales_2 = {"unet": {"down": 5, "mid": 5}}

            pipe.set_adapters("adapter-1", scales_1)
            output_adapter_1 = infer(pipe, inputs)

            pipe.set_adapters("adapter-2", scales_2)
            output_adapter_2 = infer(pipe, inputs)

            pipe.set_adapters(list(adapter_names), [scales_1, scales_2])
            output_adapter_mixed = infer(pipe, inputs)

            # the three adapter selections must give pairwise different outputs
            pairwise_checks = (
                (output_adapter_1, output_adapter_2, "Adapter 1 and 2 should give different results"),
                (
                    output_adapter_1,
                    output_adapter_mixed,
                    "Adapter 1 and mixed adapters should give different results",
                ),
                (
                    output_adapter_2,
                    output_adapter_mixed,
                    "Adapter 2 and mixed adapters should give different results",
                ),
            )
            for left, right, message in pairwise_checks:
                self.assertFalse(close(left, right), message)

            pipe.disable_lora()
            output_disabled = infer(pipe, inputs)

            self.assertTrue(
                close(output_no_lora, output_disabled),
                "output with no lora and output with lora disabled should give same results",
            )

            # a mismatching number of adapter_names and adapter_weights should raise an error
            with self.assertRaises(ValueError):
                pipe.set_adapters(list(adapter_names), [scales_1])

    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
        """Tests that any valid combination of lora block scales can be used in pipe.set_adapter"""

        def updown_options(blocks_with_tf, layers_per_block, value):
            """
            Enumerate every form the scale of the up/down part can take: the plain number, or a dict mapping
            any non-empty selection of blocks to either the number or a per-layer list of it.
            E.g. 2, {"block_1": 2}, {"block_1": [2,2,2]}, {"block_1": 2, "block_2": [2,2,2]}, ...
            """
            per_layer_scales = [value] * layers_per_block
            choices = (None, value, per_layer_scales)

            options = [value]
            for combo in product(choices, repeat=len(blocks_with_tf)):
                block_scales = {
                    f"block_{block_idx}": scale for block_idx, scale in zip(blocks_with_tf, combo) if scale is not None
                }
                # an empty dict means no block was selected at all, which is not an option of its own
                if block_scales:
                    options.append(block_scales)
            return options

        def all_possible_dict_opts(unet, value):
            """
            Enumerate the lora weight dicts built from all combinations of text encoder and unet scales.
            Only combinations that scale at least one unet part are generated.
            E.g. {"unet: {"down": 2}}, {"unet: {"down": [2,2,2]}}, {"unet: {"mid": 2, "up": [2,2,2]}}, ...
            """
            # only blocks that expose `attentions` take part in block scaling
            down_ids = [idx for idx, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")]
            up_ids = [idx for idx, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")]
            layers = unet.config.layers_per_block

            plain_choices = [None, value]
            down_choices = [None] + updown_options(down_ids, layers, value)
            # the per-layer lists of the up part carry one more entry than those of the down part
            up_choices = [None] + updown_options(up_ids, layers + 1, value)

            scale_dicts = []
            for t1, t2, d, m, u in product(plain_choices, plain_choices, down_choices, plain_choices, up_choices):
                unet_scales = {part: scale for part, scale in (("down", d), ("mid", m), ("up", u)) if scale is not None}
                if not unet_scales:
                    # no unet scaling
                    continue
                scale_dict = {}
                if t1 is not None:
                    scale_dict["text_encoder"] = t1
                if t2 is not None:
                    scale_dict["text_encoder_2"] = t2
                scale_dict["unet"] = unet_scales
                scale_dicts.append(scale_dict)

            return scale_dicts

        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_cls)
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)
        # no inference is run in this test, so the returned inputs stay unused
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")

        denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")

        has_extra_text_encoder = self.has_two_text_encoders or self.has_three_text_encoders
        if has_extra_text_encoder and "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")

        for scale_dict in all_possible_dict_opts(pipe.unet, value=1234):
            # test if lora block scales can be set with this scale_dict
            if not self.has_two_text_encoders:
                scale_dict.pop("text_encoder_2", None)

            pipe.set_adapters("adapter-1", scale_dict)  # test will fail if this line throws an error

    def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and set/delete them
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # baseline without any adapter, used as the reference once all adapters are deleted
            output_no_lora = pipe(**inputs, generator=np.random.default_rng(0))[0]

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            pipe.set_adapters("adapter-1")
            output_adapter_1 = pipe(**inputs, generator=np.random.default_rng(0))[0]

            pipe.set_adapters("adapter-2")
            output_adapter_2 = pipe(**inputs, generator=np.random.default_rng(0))[0]

            pipe.set_adapters(["adapter-1", "adapter-2"])
            output_adapter_mixed = pipe(**inputs, generator=np.random.default_rng(0))[0]

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter 1 and 2 should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 1 and mixed adapters should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 2 and mixed adapters should give different results",
            )

            # both adapters are active at this point; once adapter 1 is deleted the output is expected
            # to match the run that used adapter 2 alone
            pipe.delete_adapters("adapter-1")
            output_deleted_adapter_1 = pipe(**inputs, generator=np.random.default_rng(0))[0]

            self.assertTrue(
                np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Output after deleting adapter 1 should match the output of adapter 2 alone",
            )

            pipe.delete_adapters("adapter-2")
            output_deleted_adapters = pipe(**inputs, generator=np.random.default_rng(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
                "output with no lora and output with all adapters deleted should give same results",
            )

            # second round: add both adapters again, then delete them with a single call
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            pipe.set_adapters(["adapter-1", "adapter-2"])
            pipe.delete_adapters(["adapter-1", "adapter-2"])

            output_deleted_adapters = pipe(**inputs, generator=np.random.default_rng(0))[0]

            self.assertTrue(
                np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
                "output with no lora and output with all adapters deleted should give same results",
            )

    def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
        """
        Attaches two LoRA adapters to the text encoder(s) and the denoiser and checks that activating both
        with explicit weights gives a different output than activating both without weights
        """

        def infer(pipeline, call_kwargs):
            # every run gets a freshly seeded generator so that outputs are comparable
            return pipeline(**call_kwargs, generator=np.random.default_rng(0))[0]

        def close(first, second):
            return np.allclose(first, second, atol=1e-3, rtol=1e-3)

        adapter_names = ("adapter-1", "adapter-2")

        np.random.seed(0)
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = infer(pipe, inputs)

            lora_modules = self.pipeline_class._lora_loadable_modules
            if "text_encoder" in lora_modules:
                for adapter_name in adapter_names:
                    pipe.text_encoder.add_adapter(text_lora_config, adapter_name)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            for adapter_name in adapter_names:
                denoiser.add_adapter(denoiser_lora_config, adapter_name)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            has_extra_text_encoder = self.has_two_text_encoders or self.has_three_text_encoders
            if has_extra_text_encoder and "text_encoder_2" in lora_modules:
                for adapter_name in adapter_names:
                    pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            pipe.set_adapters("adapter-1")
            output_adapter_1 = infer(pipe, inputs)

            pipe.set_adapters("adapter-2")
            output_adapter_2 = infer(pipe, inputs)

            pipe.set_adapters(list(adapter_names))
            output_adapter_mixed = infer(pipe, inputs)

            # the three adapter selections must give pairwise different outputs
            pairwise_checks = (
                (output_adapter_1, output_adapter_2, "Adapter 1 and 2 should give different results"),
                (
                    output_adapter_1,
                    output_adapter_mixed,
                    "Adapter 1 and mixed adapters should give different results",
                ),
                (
                    output_adapter_2,
                    output_adapter_mixed,
                    "Adapter 2 and mixed adapters should give different results",
                ),
            )
            for left, right, message in pairwise_checks:
                self.assertFalse(close(left, right), message)

            # same two adapters, this time with explicit per-adapter weights
            pipe.set_adapters(list(adapter_names), [0.5, 0.6])
            output_adapter_mixed_weighted = infer(pipe, inputs)

            self.assertFalse(
                close(output_adapter_mixed_weighted, output_adapter_mixed),
                "Weighted adapter and mixed adapter should give different results",
            )

            pipe.disable_lora()
            output_disabled = infer(pipe, inputs)

            self.assertTrue(
                close(output_no_lora, output_disabled),
                "output with no lora and output with lora disabled should give same results",
            )

    def test_lora_fuse_nan(self):
        """
        Corrupts one LoRA weight of the denoiser with `inf` and checks that fusing with `safe_fusing=True`
        raises a `ValueError`, while fusing with `safe_fusing=False` succeeds and gives an all-NaN output
        """
        adapter_name = "adapter-1"

        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            lora_modules = self.pipeline_class._lora_loadable_modules
            if "text_encoder" in lora_modules:
                pipe.text_encoder.add_adapter(text_lora_config, adapter_name)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            denoiser.add_adapter(denoiser_lora_config, adapter_name)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            # corrupt one LoRA weight with `inf` values
            if self.unet_kwargs:
                corrupted = pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A[adapter_name]
                corrupted.weight += float("inf")
            else:
                cell_names = [name for name, _ in pipe.transformer.cells_and_names()]
                possible_tower_names = [
                    "transformer_blocks",
                    "blocks",
                    "joint_transformer_blocks",
                    "single_transformer_blocks",
                ]
                filtered_tower_names = [name for name in possible_tower_names if hasattr(pipe.transformer, name)]
                if not filtered_tower_names:
                    reason = f"`pipe.transformer` didn't have any of the following attributes: {possible_tower_names}."
                    raise ValueError(reason)

                # use `attn1` when any cell name of the transformer contains it, otherwise fall back to `attn`
                attn_attr = "attn1" if any("attn1" in name for name in cell_names) else "attn"
                for tower_name in filtered_tower_names:
                    first_block = getattr(pipe.transformer, tower_name)[0]
                    corrupted = getattr(first_block, attn_attr).to_q.lora_A[adapter_name]
                    corrupted.weight += float("inf")

            # with `safe_fusing=True` we should see an Error
            with self.assertRaises(ValueError):
                pipe.fuse_lora(components=lora_modules, safe_fusing=True)

            # without we should not see an error, but every value of the output is expected to be NaN
            pipe.fuse_lora(components=lora_modules, safe_fusing=False)
            out = pipe(**inputs)[0]

            self.assertTrue(np.isnan(out).all())

    def test_get_adapters(self):
        """
        Adds two adapters one after the other and checks what `get_active_adapters` reports after each
        addition and after both adapters are activated explicitly
        """
        first_adapter, second_adapter = "adapter-1", "adapter-2"

        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe.set_progress_bar_config(disable=None)
            # no inference is run in this test, so the returned inputs stay unused
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            pipe.text_encoder.add_adapter(text_lora_config, first_adapter)

            denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            denoiser.add_adapter(denoiser_lora_config, first_adapter)

            active_after_first = pipe.get_active_adapters()
            self.assertListEqual(active_after_first, [first_adapter])

            pipe.text_encoder.add_adapter(text_lora_config, second_adapter)
            denoiser.add_adapter(denoiser_lora_config, second_adapter)

            # only the adapter added last is expected to be reported as active
            active_after_second = pipe.get_active_adapters()
            self.assertListEqual(active_after_second, [second_adapter])

            pipe.set_adapters([first_adapter, second_adapter])
            self.assertListEqual(pipe.get_active_adapters(), [first_adapter, second_adapter])

    def test_get_list_adapters(self):
        """
        Checks that `get_list_adapters()` reports, for each component that received adapters, the names of all
        adapters attached so far: while adapters are added one by one, after `set_adapters()`, and after a third
        adapter is added to the denoiser only.
        """
        denoiser_name = "unet" if self.unet_kwargs is not None else "transformer"

        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe.set_progress_bar_config(disable=None)

            text_encoder_has_lora = "text_encoder" in self.pipeline_class._lora_loadable_modules
            denoiser = getattr(pipe, denoiser_name)

            # 1. and 2. Attach "adapter-1", then "adapter-2", and check the listing after each addition.
            attached = []
            for adapter_name in ("adapter-1", "adapter-2"):
                attached.append(adapter_name)
                expected = {}
                if text_encoder_has_lora:
                    pipe.text_encoder.add_adapter(text_lora_config, adapter_name)
                    expected["text_encoder"] = list(attached)

                denoiser.add_adapter(denoiser_lora_config, adapter_name)
                expected[denoiser_name] = list(attached)

                self.assertDictEqual(pipe.get_list_adapters(), expected)

            # 3. Activating both adapters must leave the listing unchanged.
            pipe.set_adapters(["adapter-1", "adapter-2"])

            expected = {denoiser_name: ["adapter-1", "adapter-2"]}
            if text_encoder_has_lora:
                expected["text_encoder"] = ["adapter-1", "adapter-2"]

            self.assertDictEqual(pipe.get_list_adapters(), expected)

            # 4. A third adapter on the denoiser only must show up for the denoiser and nowhere else.
            denoiser.add_adapter(denoiser_lora_config, "adapter-3")

            expected = {denoiser_name: ["adapter-1", "adapter-2", "adapter-3"]}
            if text_encoder_has_lora:
                expected["text_encoder"] = ["adapter-1", "adapter-2"]

            self.assertDictEqual(pipe.get_list_adapters(), expected)

    @require_peft_version_greater(peft_version="0.6.2")
    def test_simple_inference_with_text_lora_denoiser_fused_multi(
        self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
    ):
        """
        Multi-adapter fusing check: attaches "adapter-1" and "adapter-2" to the LoRA-loadable text encoder(s) and to
        the denoiser, then verifies that fusing one adapter, and later both, reproduces the matching unfused output
        within `expected_atol` / `expected_rtol`, and that `unfuse_lora()` brings the fused-LoRA count back to zero.
        """

        def run(pipeline, call_inputs):
            # Every call gets a freshly seeded generator so that outputs are comparable with each other.
            return pipeline(**call_inputs, generator=np.random.default_rng(0))[0]

        np.random.seed(0)
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = run(pipe, inputs)
            self.assertTrue(output_no_lora.shape == self.output_shape)

            loadable = self.pipeline_class._lora_loadable_modules
            text_encoder_loadable = "text_encoder" in loadable

            if text_encoder_loadable:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

            denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")

            several_text_encoders = self.has_two_text_encoders or self.has_three_text_encoders
            second_encoder_loadable = several_text_encoders and "text_encoder_2" in loadable
            if second_encoder_loadable:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")

            # Reference outputs: both adapters active, then only the first one.
            pipe.set_adapters(["adapter-1", "adapter-2"])
            outputs_all_lora = run(pipe, inputs)

            pipe.set_adapters(["adapter-1"])
            outputs_lora_1 = run(pipe, inputs)

            # Fuse only the first adapter; the LoRA layers stay in place, so the output must not move.
            pipe.fuse_lora(components=loadable, adapter_names=["adapter-1"])
            self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

            outputs_lora_1_fused = run(pipe, inputs)
            single_fuse_matches = np.allclose(
                outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol
            )
            self.assertTrue(single_fuse_matches, "Fused lora should not change the output")

            pipe.unfuse_lora(components=loadable)
            self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

            # Unfusing must not strip the LoRA layers from any of the components.
            if text_encoder_loadable:
                self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")

            self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")

            if second_encoder_loadable:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
                )

            # Fuse both adapters and compare against the multi-adapter reference output.
            pipe.fuse_lora(components=loadable, adapter_names=["adapter-2", "adapter-1"])
            self.assertTrue(pipe.num_fused_loras == 2, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

            output_all_lora_fused = run(pipe, inputs)
            double_fuse_matches = np.allclose(
                output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol
            )
            self.assertTrue(double_fuse_matches, "Fused lora should not change the output")

            pipe.unfuse_lora(components=loadable)
            self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

    def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3):
        """
        Checks that passing a LoRA `scale` through the pipeline's attention kwargs yields the same output as fusing
        the adapter with `lora_scale` set to that value, and that the adapter changes the output at all.
        """
        np.random.seed(0)
        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

        # Same iteration order as nesting the scheduler loop inside the scale loop.
        for lora_scale, scheduler_cls in product([1.0, 0.8], self.scheduler_classes):
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=np.random.default_rng(0))[0]
            self.assertTrue(output_no_lora.shape == self.output_shape)

            loadable = self.pipeline_class._lora_loadable_modules
            if "text_encoder" in loadable:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            several_text_encoders = self.has_two_text_encoders or self.has_three_text_encoders
            if several_text_encoders and "text_encoder_2" in loadable:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2),
                    "Lora not correctly set in text encoder 2",
                )

            # Reference: the scale is applied at call time through the attention kwargs.
            pipe.set_adapters(["adapter-1"])
            scale_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
            outputs_lora_1 = pipe(**inputs, generator=np.random.default_rng(0), **scale_kwargs)[0]

            # Candidate: the scale is baked into the weights at fuse time, so no kwargs are passed afterwards.
            pipe.fuse_lora(components=loadable, adapter_names=["adapter-1"], lora_scale=lora_scale)
            self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

            outputs_lora_1_fused = pipe(**inputs, generator=np.random.default_rng(0))[0]

            fused_matches_kwargs = np.allclose(
                outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol
            )
            self.assertTrue(fused_matches_kwargs, "Fused lora should not change the output")

            lora_is_noop = np.allclose(output_no_lora, outputs_lora_1, atol=expected_atol, rtol=expected_rtol)
            self.assertFalse(lora_is_noop, "LoRA should change the output")

    @require_peft_version_greater(peft_version="0.9.0")
    def test_simple_inference_with_dora(self):
        """Checks that adapters built from the `use_dora=True` dummy LoRA configs change the pipeline output."""
        for scheduler_cls in self.scheduler_classes:
            dummy_components = self.get_dummy_components(scheduler_cls, use_dora=True)
            components, text_lora_config, denoiser_lora_config = dummy_components
            pipe = self.pipeline_class(**components)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline without any adapter attached.
            baseline_output = pipe(**inputs, generator=np.random.default_rng(0))[0]
            self.assertTrue(baseline_output.shape == self.output_shape)

            pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

            dora_output = pipe(**inputs, generator=np.random.default_rng(0))[0]

            output_unchanged = np.allclose(dora_output, baseline_output, atol=1e-3, rtol=1e-3)
            self.assertFalse(output_unchanged, "DoRA lora should change the output")

    def test_missing_keys_warning(self):
        """
        Checks that loading a LoRA state dict from which one `lora_A` key was removed logs that key.

        The denoiser adapter is saved to a checkpoint, the checkpoint is reloaded as a state dict, one `lora_A`
        entry is dropped, and the output captured from the `peft_utils` logger must mention the dropped key.
        """
        scheduler_cls = self.scheduler_classes[0]
        # Skip text encoder check for now as that is handled with `transformers`.
        components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config)
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        with tempfile.TemporaryDirectory() as tmpdirname:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
            self.pipeline_class.save_lora_weights(
                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
            )
            pipe.unload_lora_weights()
            checkpoint_path = os.path.join(tmpdirname, "pytorch_lora_weights.ckpt")
            self.assertTrue(os.path.isfile(checkpoint_path))
            state_dict = ms.load_checkpoint(checkpoint_path)

        # To make things dynamic since we cannot settle with a single key for all the models where we
        # offer PEFT support.
        missing_key = [k for k in state_dict if "lora_A" in k][0]
        del state_dict[missing_key]

        logger = logging.get_logger("mindone.diffusers.utils.peft_utils")
        logger.setLevel(logging.WARNING)
        with CaptureLogger(logger) as cap_logger:
            pipe.load_lora_weights(state_dict)

        # Since the missing key won't contain the adapter name ("default_0").
        # Also strip out the component prefix (such as "unet." from `missing_key`). The prefix is taken from
        # `missing_key` itself so that the result does not depend on set iteration order.
        component = missing_key.split(".")[0]
        self.assertIn(missing_key.replace(f"{component}.", ""), cap_logger.out.replace("default_0.", ""))

    def test_unexpected_keys_warning(self):
        """
        Checks that loading a LoRA state dict containing an extra, unknown key logs that key.

        The denoiser adapter is saved to a checkpoint, the checkpoint is reloaded as a state dict, an entry whose
        name ends in ".diffusers_cat" is added, and the output captured from the `peft_utils` logger must mention
        that suffix.
        """
        scheduler_cls = self.scheduler_classes[0]
        # Skip text encoder check for now as that is handled with `transformers`.
        components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config)
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        with tempfile.TemporaryDirectory() as tmpdirname:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
            self.pipeline_class.save_lora_weights(
                save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
            )
            pipe.unload_lora_weights()
            checkpoint_path = os.path.join(tmpdirname, "pytorch_lora_weights.ckpt")
            self.assertTrue(os.path.isfile(checkpoint_path))
            state_dict = ms.load_checkpoint(checkpoint_path)

        # Derive the extra key from an existing `lora_A` key so that it lands next to real adapter weights.
        unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat"
        state_dict[unexpected_key] = ms.Parameter(ms.tensor(1.0), name=unexpected_key)

        logger = logging.get_logger("mindone.diffusers.utils.peft_utils")
        logger.setLevel(logging.WARNING)
        with CaptureLogger(logger) as cap_logger:
            pipe.load_lora_weights(state_dict)

        self.assertIn(".diffusers_cat", cap_logger.out)

    @unittest.skip("This is failing for now - need to investigate")
    def test_simple_inference_with_text_denoiser_lora_unfused_mindspore_jit(self):
        """
        Smoke test (currently skipped): attaches LoRA adapters to the pipeline, wraps the `construct` methods of
        the unet and the text encoder(s) with `ms.jit(..., fullgraph=True)` and runs one inference call.
        No output values are checked; the call only has to complete.
        """
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

            # NOTE(review): `pipe.unet` is accessed unconditionally here, while sibling tests pick
            # `pipe.transformer` when `self.unet_kwargs` is None - confirm this is intended.
            pipe.unet.construct = ms.jit(pipe.unet.construct, fullgraph=True)
            pipe.text_encoder.construct = ms.jit(pipe.text_encoder.construct, fullgraph=True)

            if self.has_two_text_encoders or self.has_three_text_encoders:
                pipe.text_encoder_2.construct = ms.jit(pipe.text_encoder_2.construct, fullgraph=True)

            # Just makes sure it works.
            _ = pipe(**inputs, generator=np.random.default_rng(0))[0]

    # def test_modify_padding_mode(self):
    #     def set_pad_mode(network, mode="circular"):
    #         for _, module in network.cells_and_names():
    #             if isinstance(module, mint.nn.Conv2d):
    #                 module.padding_mode = mode

    #     for scheduler_cls in self.scheduler_classes:
    #         components, _, _ = self.get_dummy_components(scheduler_cls)
    #         pipe = self.pipeline_class(**components)
    #         pipe.set_progress_bar_config(disable=None)
    #         _pad_mode = "circular"
    #         set_pad_mode(pipe.vae, _pad_mode)
    #         set_pad_mode(pipe.unet, _pad_mode)

    #         _, _, inputs = self.get_dummy_inputs()
    #         _ = pipe(**inputs)[0]

    def test_logs_info_when_no_lora_keys_found(self):
        """
        Checks that loading a state dict without usable LoRA keys logs a "No LoRA keys associated to ..." message
        and leaves the pipeline output unchanged.

        The check runs once through `load_lora_weights()` (message about the denoiser class) and once per
        LoRA-loadable text encoder through `load_lora_into_text_encoder()`.
        """
        scheduler_cls = self.scheduler_classes[0]
        # Skip text encoder check for now as that is handled with `transformers`.
        components, _, _ = self.get_dummy_components(scheduler_cls)
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)

        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        original_out = pipe(**inputs, generator=np.random.default_rng(0))[0]

        # None of these keys carries a component prefix, so nothing should be loaded from them.
        no_op_state_dict = {
            "lora_foo": ms.Parameter(ms.tensor(2.0), name="lora_foo"),
            "lora_bar": ms.Parameter(ms.tensor(3.0), name="lora_bar"),
        }
        logger = logging.get_logger("mindone.diffusers.loaders.peft")
        logger.setLevel(logging.WARNING)

        with CaptureLogger(logger) as cap_logger:
            pipe.load_lora_weights(no_op_state_dict)
        out_after_lora_attempt = pipe(**inputs, generator=np.random.default_rng(0))[0]

        denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
        self.assertTrue(cap_logger.out.startswith(f"No LoRA keys associated to {denoiser.__class__.__name__}"))
        self.assertTrue(np.allclose(original_out, out_after_lora_attempt, atol=1e-5, rtol=1e-5))

        # test only for text encoder
        for lora_module in self.pipeline_class._lora_loadable_modules:
            if "text_encoder" not in lora_module:
                continue

            text_encoder = getattr(pipe, lora_module)
            # The prefix equals the component name ("text_encoder", "text_encoder_2"). Using the name directly
            # keeps `prefix` defined for every matching module instead of leaving it unbound or stale.
            prefix = lora_module

            logger = logging.get_logger("mindone.diffusers.loaders.lora_base")
            logger.setLevel(logging.WARNING)

            with CaptureLogger(logger) as cap_logger:
                self.pipeline_class.load_lora_into_text_encoder(
                    no_op_state_dict, network_alphas=None, text_encoder=text_encoder, prefix=prefix
                )

            self.assertTrue(
                cap_logger.out.startswith(f"No LoRA keys associated to {text_encoder.__class__.__name__}")
            )

    def test_set_adapters_match_attention_kwargs(self):
        """
        Test to check if outputs after `set_adapters()` and attention kwargs match.

        The same LoRA scale is applied three ways: through the attention kwargs, through `set_adapters()`, and
        through the attention kwargs on a pipeline that reloaded the saved weights. All three must agree with each
        other and differ from the output without LoRA.
        """

        def run(pipeline, call_inputs, **extra_kwargs):
            # Every call gets a freshly seeded generator so that outputs are comparable with each other.
            return pipeline(**call_inputs, generator=np.random.default_rng(0), **extra_kwargs)[0]

        def outputs_match(first, second):
            return np.allclose(first, second, atol=1e-3, rtol=1e-3)

        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = run(pipe, inputs)
            self.assertTrue(output_no_lora.shape == self.output_shape)

            pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

            # Variant 1: scale passed at call time.
            lora_scale = 0.5
            attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
            output_lora_scale = run(pipe, inputs, **attention_kwargs)
            self.assertFalse(
                outputs_match(output_no_lora, output_lora_scale),
                "Lora + scale should change the output",
            )

            # Variant 2: scale registered on the adapter, no kwargs at call time.
            pipe.set_adapters("default", lora_scale)
            output_lora_scale_wo_kwargs = run(pipe, inputs)
            self.assertTrue(
                not outputs_match(output_no_lora, output_lora_scale_wo_kwargs),
                "Lora + scale should change the output",
            )
            self.assertTrue(
                outputs_match(output_lora_scale, output_lora_scale_wo_kwargs),
                "Lora + scale should match the output of `set_adapters()`.",
            )

            # Variant 3: weights saved to disk and loaded into a pipeline rebuilt from the same components.
            with tempfile.TemporaryDirectory() as tmpdirname:
                modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
                lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
                self.pipeline_class.save_lora_weights(
                    save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
                )

                weights_path = os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
                self.assertTrue(os.path.isfile(weights_path))
                pipe = self.pipeline_class(**components)
                pipe.set_progress_bar_config(disable=None)
                pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))

                for module_name, module in modules_to_save.items():
                    self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

                output_lora_from_pretrained = run(pipe, inputs, **attention_kwargs)
                self.assertTrue(
                    not outputs_match(output_no_lora, output_lora_from_pretrained),
                    "Lora + scale should change the output",
                )
                self.assertTrue(
                    outputs_match(output_lora_scale, output_lora_from_pretrained),
                    "Loading from saved checkpoints should give same results as attention_kwargs.",
                )
                self.assertTrue(
                    outputs_match(output_lora_scale_wo_kwargs, output_lora_from_pretrained),
                    "Loading from saved checkpoints should give same results as set_adapters().",
                )

    @require_peft_version_greater("0.13.2")
    def test_lora_B_bias(self):
        """
        Checks that the `lora_bias` flag of the denoiser LoRA config affects inference: the outputs without LoRA,
        with `lora_bias=False` and with `lora_bias=True` must all differ from one another.
        """
        # Currently, this test is only relevant for Flux Control LoRA as we are not
        # aware of any other LoRA checkpoint that has its `lora_B` biases trained.
        components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet

        # Snapshot of the base-layer biases of the targeted modules.
        # NOTE(review): the snapshot is not compared against anything further down in this test.
        bias_values = {
            name: module.bias.data.clone()
            for name, module in denoiser.cells_and_names()
            if any(k in name for k in self.denoiser_target_modules) and module.bias is not None
        }

        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        original_output = pipe(**inputs, generator=np.random.default_rng(0))[0]

        # First pass: adapter without a trainable `lora_B` bias; removed again afterwards.
        denoiser_lora_config.lora_bias = False
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        lora_bias_false_output = pipe(**inputs, generator=np.random.default_rng(0))[0]
        pipe.delete_adapters("adapter-1")

        # Second pass: same adapter name, this time with the `lora_B` bias enabled.
        denoiser_lora_config.lora_bias = True
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        lora_bias_true_output = pipe(**inputs, generator=np.random.default_rng(0))[0]

        output_pairs = (
            (original_output, lora_bias_false_output),
            (original_output, lora_bias_true_output),
            (lora_bias_false_output, lora_bias_true_output),
        )
        for first, second in output_pairs:
            self.assertFalse(np.allclose(first, second, atol=1e-3, rtol=1e-3))

    def test_correct_lora_configs_with_different_ranks(self):
        """
        Checks that per-module `rank_pattern` and `alpha_pattern` overrides in the denoiser LoRA config are kept
        in the registered PEFT config and change the pipeline output compared to the adapter without overrides.
        """
        components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer

        original_output = pipe(**inputs, generator=np.random.default_rng(0))[0]

        # Reference adapter: the same rank and alpha for every targeted module.
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        lora_output_same_rank = pipe(**inputs, generator=np.random.default_rng(0))[0]
        denoiser.delete_adapters("adapter-1")

        # Pick the first attention `to_k` module (ignoring LoRA sub-layers) as the module to override.
        for name, _ in denoiser.cells_and_names():
            if "to_k" in name and "attn" in name and "lora" not in name:
                module_name_to_rank_update = name.replace(".base_layer.", ".")
                break
        else:
            # Without this guard a denoiser with no matching module would surface as a NameError below.
            self.fail("Could not find an attention `to_k` module in the denoiser to apply the overrides to.")

        # change the rank_pattern
        updated_rank = denoiser_lora_config.r * 2
        denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank}

        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        updated_rank_pattern = denoiser.peft_config["adapter-1"].rank_pattern
        self.assertEqual(updated_rank_pattern, {module_name_to_rank_update: updated_rank})

        lora_output_diff_rank = pipe(**inputs, generator=np.random.default_rng(0))[0]
        self.assertFalse(np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
        self.assertFalse(np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))

        denoiser.delete_adapters("adapter-1")

        # similarly change the alpha_pattern
        updated_alpha = denoiser_lora_config.lora_alpha * 2
        denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha}

        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        self.assertEqual(
            denoiser.peft_config["adapter-1"].alpha_pattern, {module_name_to_rank_update: updated_alpha}
        )

        lora_output_diff_alpha = pipe(**inputs, generator=np.random.default_rng(0))[0]
        self.assertFalse(np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
        self.assertFalse(np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))

    def test_layerwise_casting_inference_denoiser(self):
        """
        Smoke test: builds a pipeline with LoRA adapters in float32 and runs one inference call.

        The pipeline is initialized with `storage_dtype=None`, so layerwise casting is not enabled and the dtype
        check helper is not exercised by this call.
        """
        from mindone.diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
        from mindone.diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN

        def check_linear_dtype(module, storage_dtype, compute_dtype):
            # Supported layers are expected in `storage_dtype`, except LoRA layers and layers whose name
            # matches a skip pattern, which are expected in `compute_dtype`.
            skip_patterns = DEFAULT_SKIP_MODULES_PATTERN
            extra_patterns = getattr(module, "_skip_layerwise_casting_patterns", None)
            if extra_patterns is not None:
                skip_patterns += tuple(extra_patterns)

            for name, submodule in module.cells_and_names():
                if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
                    keeps_compute_dtype = "lora" in name or any(
                        re.search(pattern, name) for pattern in skip_patterns
                    )
                    expected_dtype = compute_dtype if keeps_compute_dtype else storage_dtype
                    for param_name in ("weight", "bias"):
                        param = getattr(submodule, param_name, None)
                        if param is not None:
                            self.assertEqual(param.dtype, expected_dtype)

        def initialize_pipeline(storage_dtype=None, compute_dtype=ms.float32):
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
            pipe = self.pipeline_class(**components).to(dtype=compute_dtype)
            pipe.set_progress_bar_config(disable=None)

            pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

            if storage_dtype is None:
                return pipe

            denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
            check_linear_dtype(denoiser, storage_dtype, compute_dtype)
            return pipe

        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        pipe_fp32 = initialize_pipeline(storage_dtype=None)
        pipe_fp32(**inputs, generator=np.random.default_rng(0))[0]

    def test_inference_load_delete_load_adapters(self):
        "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            # Baseline output before any adapter is attached.
            output_no_lora = pipe(**inputs, generator=np.random.default_rng(0))[0]

            # Attach an adapter (under the default adapter name) to every component that supports LoRA.
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config)
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config)
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )

            # Reference output with the adapter active.
            output_adapter_1 = pipe(**inputs, generator=np.random.default_rng(0))[0]

            with tempfile.TemporaryDirectory() as tmpdirname:
                # Persist the adapter weights so that they can be loaded back after deletion.
                modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
                lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
                self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
                self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))

                # First, delete adapter and compare.
                pipe.delete_adapters(pipe.get_active_adapters()[0])
                output_no_adapter = pipe(**inputs, generator=np.random.default_rng(0))[0]
                # Deleting must change the output and bring it back to the no-LoRA baseline.
                self.assertFalse(np.allclose(output_adapter_1, output_no_adapter, atol=1e-3, rtol=1e-3))
                self.assertTrue(np.allclose(output_no_lora, output_no_adapter, atol=1e-3, rtol=1e-3))

                # Then load adapter and compare.
                pipe.load_lora_weights(tmpdirname)
                output_lora_loaded = pipe(**inputs, generator=np.random.default_rng(0))[0]
                # Reloading the saved weights must reproduce the output of the original adapter.
                self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))
